Use SparseDiffTools v2 for steadystateadjoint #808
Conversation
TODO: update OrdinaryDiffEq to use SparseDiffTools v2
```diff
@@ -98,8 +100,7 @@ end
 end

 if !needs_jac
-    # TODO: FixedVecJacOperator should respect the `autojacvec` of the algorithm
-    operator = FixedVecJacOperator(f, y, p, Val(DiffEqBase.isinplace(sol.prob)))
+    operator = VecJac(f, y, p; autodiff = get_autodiff_from_vjp(vjp))
```
This is not really equivalent. IIRC `VecJac` recomputes the pullback every time a call to `mul!` is made. In this case, we have a fixed input and only the seeding changes, so we compute the pullback once and just reevaluate it multiple times.
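To illustrate the distinction with plain Zygote (a minimal sketch, not this package's operator code): the fixed-input pattern builds the pullback once and then reevaluates it for each seed.

```julia
using Zygote

f(x) = [x[1]^2 + x[2], x[1] * x[2]]
x = [1.0, 2.0]

# Build the pullback once at the fixed input x...
y, back = Zygote.pullback(f, x)

# ...then reuse it for each seed v: every call is a cheap vᵀJ evaluation
# with no re-tracing of f.
vjp1, = back([1.0, 0.0])
vjp2, = back([0.0, 1.0])
```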
Is this actually fixed?
Yes, see JuliaDiff/SparseDiffTools.jl#245
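For reference, a minimal sketch of the `VecJac` operator adopted by this diff; the single-argument `f(u)` form and the `AutoZygote` backend below are illustrative assumptions (the PR itself passes the SciML-style `f`, `y`, `p` and picks the backend via `get_autodiff_from_vjp`):

```julia
using SparseDiffTools, Zygote  # Zygote loaded so the AutoZygote backend works
using ADTypes: AutoZygote

f(u) = u .^ 2
u0 = rand(3)

# Lazy vector-Jacobian-product operator; `autodiff` selects the AD backend.
L = VecJac(f, u0; autodiff = AutoZygote())

v = ones(3)
L * v  # ≈ J(u0)' * v, without materializing the Jacobian
```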
Retriggering CI.
We should add a test case to ensure that …
Does this work yet? I was using LinearSolve.jl, and I think the adjoints are not implemented yet. I need it (quite desperately), so I will probably implement something by next week.

```julia
using Zygote, ForwardDiff, SciMLSensitivity, SciMLBase, LinearSolve, ComponentArrays,
      FiniteDiff

function loss_function(θ)
    (; A, b) = θ
    prob = LinearProblem(A, b)
    sol = solve(prob, nothing)
    return sum(sol.u)
end

function loss_function_chainrules(θ)
    (; A, b) = θ
    x = A \ b
    return sum(x)
end

A = Float32[1 0; 1 -2]; b = Float32[32; -4];
θ = ComponentArray(; A, b)

loss_function(θ) ≈ loss_function_chainrules(θ)           # true

Zygote.gradient(loss_function, θ)                        # fails
Zygote.gradient(loss_function_chainrules, θ)             # works

ForwardDiff.gradient(loss_function, θ)                   # fails
ForwardDiff.gradient(loss_function_chainrules, θ)        # works

FiniteDiff.finite_difference_gradient(loss_function, θ)  # works
```
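As a sanity check on that example, the expected gradient can be written down analytically (a sketch assuming the standard adjoint identities for `x = A \ b`; with the loss `sum(x)`, the seed `x̄` is a vector of ones):

```julia
using LinearAlgebra

x = A \ b
x̄ = ones(Float32, length(x))  # seed for loss = sum(x)
b̄ = A' \ x̄                    # ∂loss/∂b
Ā = -b̄ * x'                   # ∂loss/∂A

# Ā and b̄ should match the corresponding blocks of
# FiniteDiff.finite_difference_gradient(loss_function, θ)
```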
With this branch, your example is erroring on some:

```
ERROR: LoadError: Compiling Tuple{LinearSolve.var"##solve#32", Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, typeof(solve), LinearSolve.LinearCache{Base.ReshapedArray{Float32, 2, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Tuple{}}, SubArray{Float32, 1, Vector{Float32}, Tuple{UnitRange{Int64}}, true}, Vector{Float32}, SciMLBase.NullParameters, KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}, Krylov.GmresSolver{Float32, Float32, Vector{Float32}}, SciMLOperators.IdentityOperator, SciMLOperators.IdentityOperator, Float32, true, LinearSolve.OperatorCondition.IllConditioned}, KrylovJL{typeof(Krylov.gmres!), Int64, Tuple{}, Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
https://fluxml.ai/Zygote.jl/latest/limitations
```
Yes, I know, I gave that as a test case that doesn't work. The linear problem doesn't have tests because the adjoints aren't implemented. Also, taking a look at how LinearSolve works, it might need a bit of refactoring before we can include adjoints, since it dispatches on `solve` instead of `__solve`.
We can make a higher-level interface in SciMLBase and dispatch it on `__solve`.
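For what it's worth, a hedged sketch of what that adjoint could look like through ChainRulesCore; `linsolve_with_adjoint` is a hypothetical stand-in for the high-level entry point, which in practice would be the SciMLBase `solve` dispatching on `__solve`:

```julia
using ChainRulesCore, LinearAlgebra

# Hypothetical stand-in for the high-level dense linear solve.
linsolve_with_adjoint(A, b) = A \ b

function ChainRulesCore.rrule(::typeof(linsolve_with_adjoint), A, b)
    x = A \ b
    function linsolve_pullback(x̄)
        b̄ = A' \ unthunk(x̄)  # adjoint solve with the transposed operator
        Ā = -b̄ * x'          # rank-one cotangent for A
        return NoTangent(), Ā, b̄
    end
    return x, linsolve_pullback
end
```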
We can get this PR merged, right? The linear solve issue is entirely tangential and needs to be handled downstream first.
Someone needs to resolve the merge conflicts.
No Enzyme or ReverseDiff though?
Part of: SciML/SciMLOperators.jl#142